{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Individual Strings Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method = 'gcg'\n",
    "logdir = f'results/'\n",
    "\n",
    "# for individual experiments\n",
    "individual = True\n",
    "mode = 'strings'\n",
    "\n",
    "files = !ls {logdir}individual_{mode}_*_ascii*\n",
    "files = [f for f in files if 'json' in f]\n",
    "files = sorted(files, key=lambda x: \"_\".join(x.split('_')[:-1]))\n",
    "\n",
    "max_examples = 100\n",
    "\n",
    "logs = []\n",
    "for logfile in files:\n",
    "    with open(logfile, 'r') as f:\n",
    "        logs.append(json.load(f))\n",
    "log = logs[0]\n",
    "len(logs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = log['params']\n",
    "print(config.keys())\n",
    "\n",
    "total_steps = config['n_steps']\n",
    "test_steps = config.get('test_steps', 50)\n",
    "log_steps = total_steps // test_steps + 1\n",
    "print('log_steps', log_steps)\n",
    "\n",
    "if individual:\n",
    "    examples = 0\n",
    "    test_logs = []\n",
    "    control_logs = []\n",
    "    goals, targets = [],[]\n",
    "    for l in logs:\n",
    "        sub_test_logs = l['tests']\n",
    "        sub_examples = len(sub_test_logs) // log_steps\n",
    "        examples += sub_examples\n",
    "        test_logs.extend(sub_test_logs[:sub_examples * log_steps])\n",
    "        control_logs.extend(l['controls'][:sub_examples * log_steps])\n",
    "        goals.extend(l['params']['goals'][:sub_examples])\n",
    "        targets.extend(l['params']['targets'][:sub_examples])\n",
    "        if examples >= max_examples:\n",
    "            break\n",
    "else:\n",
    "    test_logs = log['tests']\n",
    "    examples = 1\n",
    "\n",
    "passed, em, loss, total = [],[],[],[]\n",
    "for i in range(examples):\n",
    "    sub_passed, sub_em, sub_loss, sub_total = [],[],[],[]\n",
    "    for res in test_logs[i*log_steps:(i+1)*log_steps]:\n",
    "        sub_passed.append(res['n_passed'])\n",
    "        sub_em.append(res['n_em'])\n",
    "        sub_loss.append(res['n_loss'])\n",
    "        sub_total.append(res['total'])\n",
    "    passed.append(sub_passed)\n",
    "    em.append(sub_em)\n",
    "    loss.append(sub_loss)\n",
    "    total.append(sub_total)\n",
    "passed = np.array(passed)\n",
    "em = np.array(em)\n",
    "loss = np.array(loss)\n",
    "total = np.array(total)\n",
    "total.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "em[...,0].mean(0)[-1], loss[...,0].mean(0)[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Individual Behaviors Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(To get more accurate results, please run the cells below, then use `evaluate_individual.py`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method = 'gcg'\n",
    "logdir = f'results/'\n",
    "\n",
    "# for individual experiments\n",
    "individual = True\n",
    "mode = 'behaviors'\n",
    "\n",
    "files = !ls {logdir}individual_{mode}_*_ascii*\n",
    "files = [f for f in files if 'json' in f]\n",
    "files = sorted(files, key=lambda x: \"_\".join(x.split('_')[:-1]))\n",
    "\n",
    "max_examples = 100\n",
    "\n",
    "logs = []\n",
    "for logfile in files:\n",
    "    with open(logfile, 'r') as f:\n",
    "        logs.append(json.load(f))\n",
    "log = logs[0]\n",
    "len(logs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = log['params']\n",
    "print(config.keys())\n",
    "\n",
    "total_steps = config['n_steps']\n",
    "test_steps = config.get('test_steps', 50)\n",
    "log_steps = total_steps // test_steps + 1\n",
    "print('log_steps', log_steps)\n",
    "\n",
    "if individual:\n",
    "    examples = 0\n",
    "    test_logs = []\n",
    "    control_logs = []\n",
    "    goals, targets = [],[]\n",
    "    for l in logs:\n",
    "        sub_test_logs = l['tests']\n",
    "        sub_examples = len(sub_test_logs) // log_steps\n",
    "        examples += sub_examples\n",
    "        test_logs.extend(sub_test_logs[:sub_examples * log_steps])\n",
    "        control_logs.extend(l['controls'][:sub_examples * log_steps])\n",
    "        goals.extend(l['params']['goals'][:sub_examples])\n",
    "        targets.extend(l['params']['targets'][:sub_examples])\n",
    "        if examples >= max_examples:\n",
    "            break\n",
    "else:\n",
    "    test_logs = log['tests']\n",
    "    examples = 1\n",
    "examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "passed, em, loss, total, controls = [],[],[],[],[]\n",
    "for i in range(examples):\n",
    "    sub_passed, sub_em, sub_loss, sub_total, sub_control = [],[],[],[],[]\n",
    "    for res in test_logs[i*log_steps:(i+1)*log_steps]:\n",
    "        sub_passed.append(res['n_passed'])\n",
    "        sub_em.append(res['n_em'])\n",
    "        sub_loss.append(res['n_loss'])\n",
    "        sub_total.append(res['total'])\n",
    "    sub_control = control_logs[i*log_steps:(i+1)*log_steps]\n",
    "    passed.append(sub_passed)\n",
    "    em.append(sub_em)\n",
    "    loss.append(sub_loss)\n",
    "    total.append(sub_total)\n",
    "    controls.append(sub_control)\n",
    "passed = np.array(passed)\n",
    "em = np.array(em)\n",
    "loss = np.array(loss)\n",
    "total = np.array(total)\n",
    "total.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "saved_controls = [c[-1] for c in controls]\n",
    "json_obj = {\n",
    "    'goal': goals,\n",
    "    'target': targets,\n",
    "    'controls': saved_controls\n",
    "}\n",
    "with open('results/individual_behavior_controls.json', 'w') as f:\n",
    "    json.dump(json_obj, f)\n",
    "\n",
    "# now run `evaluate_individual.py` with this file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = json.load(open('eval/individual_behavior_controls.json', 'r'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(np.array(data['Vicuna-7B']['jb']) == 1).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transfer Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run `evaluate.py` on the logfile first to generate a log in the eval dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_log(log, jb, idx=-1):\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(15, 3))\n",
    "\n",
    "    # Plotting the bars in the first plot\n",
    "    bars = axes[0].bar(log.keys(), jb[:, idx])\n",
    "    axes[0].xaxis.set_tick_params(rotation=90)\n",
    "    axes[0].grid(axis='y', ls='dashed')\n",
    "\n",
    "    # Plotting the lines in the second plot\n",
    "    lines = []\n",
    "    for i in range(len(log)):\n",
    "        line, = axes[1].plot(range(len(jb[0])), jb[i], label=list(log.keys())[i])\n",
    "        lines.append(line)\n",
    "\n",
    "    # Getting the handles and labels from the legend of the second plot\n",
    "    handles, labels = axes[1].get_legend_handles_labels()\n",
    "\n",
    "    # Plotting the legend in the first plot using the handles and labels from the second plot\n",
    "    axes[0].legend(handles=handles, labels=labels, bbox_to_anchor=(1.1, -0.45, 2., .102),\n",
    "                loc='lower left', ncol=4, mode=\"expand\", borderaxespad=0.)\n",
    "\n",
    "    axes[2].plot(range(len(jb[0])), jb.mean(0), color='red')\n",
    "    axes[2].set_ylim(0, 100)\n",
    "    axes[2].grid(axis='y', ls='dashed')\n",
    "\n",
    "    # Matching the colors of the bars in the first plot with the lines in the legend\n",
    "    for bar, line in zip(bars, lines):\n",
    "        bar.set_color(line.get_color())\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logdir = f'eval/'\n",
    "logfile = <your_logfile>\n",
    "\n",
    "with open(logdir + logfile, 'r') as f:\n",
    "    log = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jb, em = [],[]\n",
    "for model in log:\n",
    "    stats = log[model]\n",
    "    jb.append(stats['test_jb'])\n",
    "    em.append(stats['test_em'])\n",
    "jb = np.array(jb)\n",
    "jb = jb.mean(-1)\n",
    "em = np.array(em)\n",
    "em = em.mean(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_log(log, jb, idx=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "display",
   "language": "python",
   "name": "base"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
